{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f4TSNCvpENrW"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "vamNSA0vEP-m"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "asd4sdga7g"
      },
      "source": [
        "# Classifying CIFAR-10 with XLA\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b7noD9NjFRL-"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/autoclustering_xla\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mz65veHXsmnS"
      },
      "source": [
        "This tutorial trains a TensorFlow model to classify the [CIFAR-10](https://en.wikipedia.org/wiki/CIFAR-10) dataset, first without XLA and then with XLA compilation enabled, and compares the training times.\n",
        "\n",
        "Load and normalize the dataset using the [TensorFlow Datasets](https://tensorflow.org/datasets) API:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R4xtYyOf78e3"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PH2HbLW65tmo"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7vm2QsMisCxI"
      },
      "outputs": [],
      "source": [
        "# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb\n",
        "assert(tf.test.gpu_device_name())\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "tf.config.optimizer.set_jit(False) # Start with XLA disabled.\n",
        "\n",
        "def load_data():\n",
        "  \"\"\"Loads the CIFAR-10 train and test splits with TFDS.\n",
        "\n",
        "  Returns:\n",
        "    A `((x_train, y_train), (x_test, y_test))` tuple. Images are cast to\n",
        "    float32 and divided by 256; labels are one-hot encoded over 10 classes.\n",
        "  \"\"\"\n",
        "  splits = tfds.load('cifar10', batch_size=-1)\n",
        "\n",
        "  def prepare(split_name):\n",
        "    split = splits[split_name]\n",
        "    images = split['image'].numpy().astype('float32') / 256\n",
        "    # Convert class vectors to binary class matrices.\n",
        "    labels = tf.keras.utils.to_categorical(split['label'], num_classes=10)\n",
        "    return images, labels\n",
        "\n",
        "  return prepare('train'), prepare('test')\n",
        "\n",
        "(x_train, y_train), (x_test, y_test) = load_data()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MgNM2tbgtScx"
      },
      "source": [
        "We define the model, adapted from the Keras [CIFAR-10 example](https://keras.io/examples/cifar10_cnn/):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3ZRQSwoRsKM_"
      },
      "outputs": [],
      "source": [
        "def generate_model():\n",
        "  \"\"\"Builds the uncompiled CIFAR-10 CNN: two conv stages and a dense softmax head.\n",
        "\n",
        "  The input shape is read from the global `x_train`, so the data-loading cell\n",
        "  must have run first.\n",
        "  \"\"\"\n",
        "  layers = tf.keras.layers\n",
        "  image_shape = x_train.shape[1:]\n",
        "\n",
        "  conv_stage_32 = [\n",
        "    layers.Conv2D(32, (3, 3), padding='same', input_shape=image_shape),\n",
        "    layers.Activation('relu'),\n",
        "    layers.Conv2D(32, (3, 3)),\n",
        "    layers.Activation('relu'),\n",
        "    layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "    layers.Dropout(0.25),\n",
        "  ]\n",
        "  conv_stage_64 = [\n",
        "    layers.Conv2D(64, (3, 3), padding='same'),\n",
        "    layers.Activation('relu'),\n",
        "    layers.Conv2D(64, (3, 3)),\n",
        "    layers.Activation('relu'),\n",
        "    layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "    layers.Dropout(0.25),\n",
        "  ]\n",
        "  classifier_head = [\n",
        "    layers.Flatten(),\n",
        "    layers.Dense(512),\n",
        "    layers.Activation('relu'),\n",
        "    layers.Dropout(0.5),\n",
        "    layers.Dense(10),\n",
        "    layers.Activation('softmax'),\n",
        "  ]\n",
        "  return tf.keras.models.Sequential(conv_stage_32 + conv_stage_64 + classifier_head)\n",
        "\n",
        "model = generate_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-M4GtGDZtb8a"
      },
      "source": [
        "We train the model using the\n",
        "[RMSprop](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop)\n",
        "optimizer:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UKCmrhF0tiMa"
      },
      "outputs": [],
      "source": [
        "def compile_model(model):\n",
        "  \"\"\"Compiles `model` in place for one-hot classification and returns it.\n",
        "\n",
        "  Uses RMSprop (learning rate 1e-4), categorical cross-entropy loss and the\n",
        "  accuracy metric.\n",
        "  \"\"\"\n",
        "  # No `decay` argument: current Keras optimizers reject it, and at 1e-6 it\n",
        "  # barely changed the learning rate over a run of this length.\n",
        "  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)\n",
        "  model.compile(loss='categorical_crossentropy',\n",
        "                optimizer=opt,\n",
        "                metrics=['accuracy'])\n",
        "  return model\n",
        "\n",
        "model = compile_model(model)\n",
        "\n",
        "def train_model(model, x_train, y_train, x_test, y_test, epochs=25):\n",
        "  \"\"\"Fits `model` on the training arrays, using the test arrays as validation data.\n",
        "\n",
        "  Runs `epochs` epochs with shuffled batches of 256 examples and returns nothing.\n",
        "  \"\"\"\n",
        "  fit_options = dict(\n",
        "      batch_size=256,\n",
        "      epochs=epochs,\n",
        "      shuffle=True,\n",
        "      validation_data=(x_test, y_test),\n",
        "  )\n",
        "  model.fit(x_train, y_train, **fit_options)\n",
        "\n",
        "def warmup(model, x_train, y_train, x_test, y_test):\n",
        "  \"\"\"Trains for one throwaway epoch, then restores the weights `model` started with.\n",
        "\n",
        "  This keeps one-time start-up work (such as JIT compilation) out of the timed run.\n",
        "  Only the weights returned by `get_weights()` are restored.\n",
        "  \"\"\"\n",
        "  weights_before = model.get_weights()\n",
        "  train_model(model, x_train, y_train, x_test, y_test, epochs=1)\n",
        "  model.set_weights(weights_before)\n",
        "\n",
        "warmup(model, x_train, y_train, x_test, y_test)\n",
        "%time train_model(model, x_train, y_train, x_test, y_test)\n",
        "\n",
        "scores = model.evaluate(x_test, y_test, verbose=1)\n",
        "print('Test loss:', scores[0])\n",
        "print('Test accuracy:', scores[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SLpfQ0StRgsu"
      },
      "source": [
        "Now let's train the model again, using the XLA compiler.\n",
        "To enable the compiler in the middle of the application, we need to reset the Keras session."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jxU-Tzy4SX7p"
      },
      "outputs": [],
      "source": [
        "# We need to clear the session to enable JIT in the middle of the program.\n",
        "tf.keras.backend.clear_session()\n",
        "tf.config.optimizer.set_jit(True) # Enable XLA.\n",
        "model = compile_model(generate_model())\n",
        "(x_train, y_train), (x_test, y_test) = load_data()\n",
        "\n",
        "warmup(model, x_train, y_train, x_test, y_test)\n",
        "%time train_model(model, x_train, y_train, x_test, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iWHz6P1se92F"
      },
      "source": [
        "On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "CIFAR-10 with XLA.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
